""" hard-coded version of solve_for_unknowns function """
# TODO: find Julia packages that provide equivalents of SciPy's multivariate solvers

#=
using Optim # multivariate solvers
using NLsolve # multivariate solvers
using Roots # univariate solvers
using JSOSolvers # multivariate solvers - df-sane
# using LsqFit # multivariate solvers - lm

function solve_for_unknowns(residual, unknowns, solver, solver_kwargs; residual_kwargs=nothing, constrained_method="linear_continuation", constrained_kwargs=nothing, tol=2e-12, verbose=False)

    if residual_kwargs === nothing
        residual_kwargs = Dict()
    end

    # scipy univariate and multivariate solvers
    optimize_uni_solvers = ["bisect", "brentq", "brenth", "ridder", "toms748", "newton", "secant", "halley"] # functions done
    # optimize_multi_solvers = ["hybr" # = nlsolve(function, x)
    #                         , "lm" # = nlsolve(function, x, method = :lm) # may need lsqfit package
    #                         , "broyden1" # = nlsolve(function, x, method = :broyden)
    #                         , "broyden2" # = nlsolve(function, x, method = :broyden)
    #                         , "anderson" # = nlsolve(function, x, method = :anderson)
    #                         , "linearmixing" # = nlsolve(function, x, method = :fixedpoint, m=1)
    #                         , "diagbroyden" # = nlsolve(function,x, method = :broyden)
    #                         , "excitingmixing" # = nlsolve(function, x, method = :fixedpoint, m=1, alpha=0.5) # may need to check if this is a unique method?
    #                         , "krylov" # = # nlsolve(function, x, method = :broyden) # I think this is broyden but need to check
    #                         , "df-sane" # = df_solver(function, x) # do we need this
    #                          ]

    optimize_multi_solvers = ["lm", "broyden", "anderson", "fixedpoint"]

    solve_for_unknowns(f, x, method = broyden)
    if solver in optimize_multi_solvers
        return nlsolve(f, x, method=:solver)
    end

    # wrap kwargs into the resid function
    residual_f = (x -> residual(x; residual_kwargs...))

    if solver === nothing
        throw(RuntimeError("Must provide a numerical solver from the following set: brentq, broyden, solved"))
    elseif solver in scipy_optimize_uni_solvers
        initial_values_or_bounds = extract_univariate_initial_values_or_bounds(unknowns)
        result = findzero(residual_f, initial_values_or_bounds, method = :solver; atol=tol, solver_kwargs...)
        if result === nothing
            throw(ValueError("Steady state solver, $solver, did not converge."))
        end
        unknown_solutions = result # unsure if this is correct; py is result.root
    elseif solver in scipy_optimize_multiple_solvers
        initial_values, bounds = extract_multivariate_initial_values_and_bounds(unknowns)
        # If no bounds provided
        if bounds === nothing
            result = nlsolve(residual_f, initial_values; tol=tol, solver_kwargs...)
        else
            constrained_residual = constrained_multivariate_residual(residual_f, residual_f, bounds, verbose=verbose, method=constrained_method, constrained_kwargs...)
            result = nlsolve(constrained_residual, initial_values, method = :solver; tol=tol, solver_kwargs...)
        end
        if !result.converged
            throw(ValueError("Steady-state solver, $solver, did not converge. The termination status is $(result.fails) with code $result.code)."))
        end
        unknown_solutions = result.zero
    end
    return Dict(zip(keys(unknowns), unknown_solutions))
end
# stopped before implementing custom solvers; may need to do this later
#
#
#
#
""" hopefully somewhat less hardcoded version of solve for unknowns using multiple dispatch """
# The helper functions extract_univariate_initial_values_or_bounds and extract_multivariate_initial_values_and_bounds may no longer be needed by this version (to be confirmed).

using Optim # multivariate solvers
using NLsolve # multivariate solvers
using Roots # univariate solvers

function solve_for_unknowns(residual, unknowns, solver, solver_kwargs...;
                            residual_kwargs=Dict(), constrained_method="linear_continuation",
                            constrained_kwargs=Dict(), tol=2e-12, verbose=false)

    # Wrap kwargs into the residual function
    residual_f = x -> residual(x; residual_kwargs...)

    if solver === nothing
        error("Must provide a numerical solver.")
    end

    if length(unknowns) == 1
        # Univariate solver
        initial_value_or_bounds = first(collect(unknowns))  # Assuming it's a single value or bounds
        result = find_zero(residual_f, initial_value_or_bounds; atol=tol, solver_kwargs...)
        if result === nothing
            error("Steady-state solver, $solver, did not converge.")
        end
        unknown_solutions = result
    else
        # Multivariate solver
        initial_values = collect(values(unknowns))
        result = nlsolve(residual_f, initial_values; xtol=tol, solver_kwargs...)
        if !result.converged
            error("Steady-state solver, $solver, did not converge. The termination status is $(result.fails) with code $(result.code).")
        end
        unknown_solutions = result.zero
    end
    return Dict(zip(keys(unknowns), unknown_solutions))
end
=#
"""
    instantiate_steady_state_mutable_kwargs(dissolve, block_kwargs, solver_kwargs, constrained_kwargs)

Swap every argument that is `nothing` for a fresh, empty container and return the four
values as a tuple, in the order they were given.

`dissolve` falls back to an empty `Vector{Any}`; each of the three `*_kwargs` arguments
falls back to an empty `Dict()`. Arguments that are not `nothing` are passed through
unchanged (the very same object is returned).
"""
function instantiate_steady_state_mutable_kwargs(dissolve, block_kwargs, solver_kwargs, constrained_kwargs)
    dissolve_out = dissolve === nothing ? Vector{Any}() : dissolve
    block_out = block_kwargs === nothing ? Dict() : block_kwargs
    solver_out = solver_kwargs === nothing ? Dict() : solver_kwargs
    constrained_out = constrained_kwargs === nothing ? Dict() : constrained_kwargs
    return dissolve_out, block_out, solver_out, constrained_out
end

"""
    provide_solver_default(unknowns)

Pick the name of a default solver from the shape of `unknowns`.

- Exactly one unknown: its value must be a `Tuple` whose first entry is not greater than
  its second; the result is `"brentq"`.
- More than one unknown: every value must be a `Real`; the result is `"broyden_custom"`.
- No unknowns, or values that do not satisfy the rules above: an error is raised.
"""
function provide_solver_default(unknowns)
    n_unknowns = length(unknowns)

    if n_unknowns == 1
        candidate = first(values(unknowns))
        # A usable bracket is a tuple whose first entry does not exceed the second.
        if isa(candidate, Tuple) && !(candidate[1] > candidate[2])
            return "brentq"
        end
        error("Unable to find a compatible one-dimensional solver with provided unknowns. Please provide valid lower/upper bounds, e.g. unknowns = {'a' => (0,1)}")
    elseif n_unknowns > 1
        if all(v -> isa(v, Real), values(unknowns))
            return "broyden_custom"
        end
        error("Unable to find a compatible multi-dimensional solver with provided unknowns. Please provide valid initial values, e.g. unknowns = {'a': 1, 'b': 2}")
    end

    error("unknowns is empty! Please provide a dict of keys/values equal to the number of unknowns that need to be solved for.")
end

"""
    run_consistency_check(cresid, ctol = 1e-9, fragile = false)

Compare the residual `cresid` against the tolerance `ctol`.

Nothing happens when `cresid <= ctol`. When `cresid > ctol` the function raises an error if
`fragile` is `true`, and otherwise emits a warning through the logging system carrying the
same message. Returns `nothing`.
"""
function run_consistency_check(cresid, ctol = 1e-9, fragile = false)
    cresid > ctol || return nothing

    msg = "The target values evaluated for the proposed set of unknowns produce a maximum residual value of '$(cresid)', which is greater than the ctol '$(ctol)'. If used, check if HelperBlocks are indeed compatible with the DAG. If this is not an issue, adjust ctol accordingly."
    if fragile
        error(msg)
    else
        # Pass the text itself to `@warn`. Wrapping it in `println` printed the text to
        # stdout and logged a warning whose message was `nothing`.
        @warn msg
    end
    return nothing
end

"""
    compute_target_values(targets, potential_args)

Evaluate the residual of each target using the values stored in `potential_args`.

`targets` is either a collection of names or a dictionary mapping each name to its goal:

- a plain collection of names: the residual of `name` is `potential_args[name] - 0`;
- a dictionary entry `name => goal` whose `goal` is a string: the residual is
  `potential_args[name] - potential_args[goal]`;
- a dictionary entry `name => goal` with any other `goal`: the residual is
  `potential_args[name] - goal`.

Any `AbstractDict` is treated as a dictionary of targets, not only a concrete `Dict`, so
other dictionary types (for example insertion-ordered ones) work as well.

Returns a single number when there is exactly one target and a `Vector{Float64}` otherwise.
"""
function compute_target_values(targets, potential_args)
    is_mapping = targets isa AbstractDict
    names = is_mapping ? keys(targets) : targets

    target_values = zeros(length(targets))
    for (i, name) in enumerate(names)
        goal = is_mapping ? targets[name] : 0
        # A string goal refers to another entry of `potential_args`.
        goal_value = goal isa AbstractString ? potential_args[goal] : goal
        target_values[i] = potential_args[name] - goal_value
    end

    # With a single target a bare number is returned instead of a one-element vector.
    return length(targets) == 1 ? target_values[1] : target_values
end


"""
    compare_steady_states(ss_ref, ss_comp, tol = 1e-8, name_map = nothing, internal = true,
                          check_same_keys = true, verbose = false)

Check whether two steady-state containers hold the same values and return `true`/`false`.

# Arguments
- `ss_ref`, `ss_comp`: containers indexable by key (e.g. dictionaries) to compare.
- `tol`: absolute tolerance applied to each residual.
- `name_map`: optional dictionary translating a key of `ss_ref` into the corresponding key
  of `ss_comp`; `nothing` means no renaming.
- `internal`: when `true` and both inputs have an `internal` property, every entry of
  `ss_ref.internal` is compared with the entry of `ss_comp.internal` under the same key.
  When either input lacks that property a warning is logged and only the top level is
  compared.
- `check_same_keys`: when `true`, the comparison fails if, after removing mapped names,
  the two key sets differ.
- `verbose`: when `true`, residuals and mismatching keys are printed. Note that in this
  mode residuals are only printed and do not affect the returned value.

Keys of the reference that are neither in the compared container nor in `name_map` are
skipped during the value comparison. Numeric values are compared by absolute difference,
everything else by the largest absolute entry of the difference.
"""
function compare_steady_states(ss_ref, ss_comp, tol = 1e-8, name_map = nothing, internal = true, check_same_keys = true, verbose = false)
    if name_map === nothing
        name_map = Dict()
    end

    valid = true

    # `Any[...]` so that tuples holding differently typed internal containers can be added.
    ds_to_check = Any[(ss_ref, ss_comp, "toplevel")]
    if internal
        if !hasproperty(ss_ref, :internal) || !hasproperty(ss_comp, :internal)
            @warn("The provided steady state dicts do not both have .internal attrs. Will only compare top-level values.")
        else
            for i in keys(ss_ref.internal)
                push!(ds_to_check, (ss_ref.internal[i], ss_comp.internal[i], string(i) * "_internal"))
            end
        end
    end

    for (d_ref, d_comp, level) in ds_to_check
        for key_ref in keys(d_ref)
            if key_ref in keys(d_comp)
                key_comp = key_ref
            elseif haskey(name_map, key_ref)
                key_comp = name_map[key_ref]
            else
                # No counterpart to compare against; missing keys are handled below.
                continue
            end

            ref_val = d_ref[key_ref]
            comp_val = d_comp[key_comp]
            if ref_val isa Number
                resid = abs(ref_val - comp_val)
            else
                # Infinity norm of the difference, written without LinearAlgebra.
                delta = ref_val - comp_val
                resid = isempty(delta) ? 0.0 : maximum(abs, delta)
            end

            if verbose
                println("'$(key_ref)' resid: '$(resid)'")
            elseif !all(isapprox.(resid, 0.0; atol = tol))
                valid = false
            end
        end

        # Key-set comparison is done once per level, after all values were compared.
        if check_same_keys
            d_ref_incl_mapped = setdiff(Set(keys(d_ref)), Set(keys(name_map)))
            d_comp_incl_mapped = setdiff(Set(keys(d_comp)), Set(values(name_map)))
            diff_keys = symdiff(d_ref_incl_mapped, d_comp_incl_mapped)
            if !isempty(diff_keys)
                if verbose
                    println("At level '$(level)', the keys present only one of the two steady state dicts are '$(diff_keys)'")
                end
                valid = false
            end
        end
    end
    return valid
end

# Shared scalar test used by several helpers in this file: numbers (which includes `Bool`),
# strings and single characters count as scalars; everything else does not.
isscalar(x) = x isa Number || x isa AbstractString || x isa Char

"""
    extract_univariate_initial_values_or_bounds(unknowns)

Look at the first value stored in `unknowns` and wrap it in a one-entry dictionary.

A scalar value (as decided by `isscalar`) is returned under the key `"x0"`; anything else
is indexed at positions 1 and 2 and returned as a tuple under the key `"bracket"`.
"""
function extract_univariate_initial_values_or_bounds(unknowns)
    spec = first(values(unknowns))
    isscalar(spec) && return Dict("x0" => spec)
    return Dict("bracket" => (spec[1], spec[2]))
end

"""
    extract_multivariate_initial_values_and_bounds(unknowns, fragile = false)

Split `unknowns` (name => specification) into a vector of initial values and a dictionary
of bounds, returned as `(initial_values, multi_bounds)`.

Each specification is handled according to its form:

- a number: used directly as the initial value;
- length 2, `(lower, upper)`: when `fragile` is `true` an error is raised; otherwise a
  warning is logged and the midpoint is used as the initial value (no bounds are recorded);
- length 3, `(lower, initial, upper)`: `initial` becomes the initial value and
  `(lower, upper)` is stored in `multi_bounds` under the unknown's name. An error is raised
  unless `lower < initial < upper`;
- anything else: an error is raised.

Initial values are appended in the iteration order of `unknowns`.
"""
function extract_multivariate_initial_values_and_bounds(unknowns, fragile = false)
    initial_values = Vector{Any}()
    multi_bounds = Dict()
    for (k, v) in pairs(unknowns)
        # Only numbers are accepted as bare initial values for the solver.
        if v isa Number
            push!(initial_values, v)
        elseif length(v) == 2
            if fragile
                error("'$(length(v))' is an invalid size for the value of an unknown. The values of unknowns must either be a scalar, pertaining to a single initial value for the root solver to begin from, a length 2 tuple, pertaining to a lower bound and an upper bound, or a length 3 tuple, pertaining to a lower bound, initial value, and upper bound.")
            else
                @warn "Interpreting values of `unknowns` from length 2 tuple as lower and upper bounds and averaging them to get a scalar initial value to provide to the solver."
                push!(initial_values, (v[1] + v[2]) / 2)
            end
        elseif length(v) == 3
            lb, iv, ub = v
            if !(lb < iv < ub)
                error("Invalid bounds for unknown '$(k)': expected lower bound < initial value < upper bound, got '$(v)'.")
            end
            push!(initial_values, iv)
            multi_bounds[k] = (lb, ub)
        else
            error("'$(length(v))' is an invalid size for the value of an unknown. The values of `unknowns` must either be a scalar, pertaining to a single initial value for the root solver to begin from, a length 2 tuple, pertaining to a lower bound and an upper bound, or a length 3 tuple, pertaining to a lower bound, initial value, and upper bound.")
        end
    end
    return initial_values, multi_bounds
end

"""
    residual_with_linear_continuation(residual, bounds, eval_at_boundary = false,
                                      boundary_epsilon = 1e-4, penalty_scale = 1e1,
                                      verbose = false)

Wrap `residual` in a function that never evaluates it outside of `bounds`.

`bounds` maps each unknown to a `(lower, upper)` pair; the lower and upper limits are
collected in the iteration order of `values(bounds)`.
NOTE(review): the entries of `x` are assumed to follow that same order - confirm at the
call site.

The returned function `constr_residual(x, residual_cache)`:

1. moves every out-of-bounds entry of `x` back inside: onto the bound itself when
   `eval_at_boundary` is `true`, otherwise `boundary_epsilon` inside of it;
2. evaluates `residual` at that censored point;
3. if the result contains `NaN`, replaces it by the last NaN-free residual times
   `penalty_scale`; otherwise stores the result as the new cached residual;
4. adds the penalty `(x - x_censored) * penalty_scale * residual_censored`, which is zero
   whenever `x` was already within the bounds.

The input `x` is never modified. The cache is shared between calls of the same returned
function and starts out as `NaN`; a caller may supply its own cache as a `Ref` in the
optional second argument. When `verbose` is `true` intermediate values are printed.
"""
function residual_with_linear_continuation(residual, bounds, eval_at_boundary = false, boundary_epsilon = 1e-4, penalty_scale = 1e1, verbose = false)
    lbs = [v[1] for v in values(bounds)]
    ubs = [v[2] for v in values(bounds)]

    # Created once per wrapper so the last valid residual survives between calls.
    last_valid_residual = Ref{Any}(NaN)

    function constr_residual(x, residual_cache = last_valid_residual)
        # Build a new censored point instead of writing into `x`: aliasing `x` would both
        # modify the caller's vector and make the boundary penalty below vanish.
        if eval_at_boundary
            x_censored = min.(max.(x, lbs), ubs)
        else
            x_censored = ifelse.(x .< lbs, lbs .+ boundary_epsilon,
                                 ifelse.(x .> ubs, ubs .- boundary_epsilon, x))
        end

        residual_censored = residual(x_censored)
        if verbose
            println("Attempted x is '$(x)'")
            println("Censored x is '$(x_censored)'")
            println("The residual_censored is '$(residual_censored)'")
        end

        if any(isnan, residual_censored)
            # Undefined region: fall back to a scaled copy of the last valid residual.
            residual_censored = residual_cache[] .* penalty_scale
            if verbose
                println("The new residual_censored is '$(residual_censored)'")
            end
        else
            residual_cache[] = residual_censored
            if verbose
                println("The residual cache is '$(residual_cache[])'")
            end
        end

        # Extra penalty that grows with the distance between `x` and the censored point.
        return residual_censored .+ (x .- x_censored) .* penalty_scale .* residual_censored
    end
    return constr_residual
end

"""
    constrained_multivariate_residual(residual, bounds, method = "linear_continuation",
                                      verbose = false, constrained_kwargs...)

Build a bounded version of `residual` using the requested `method`.

Only `"linear_continuation"` is implemented; any other `method` raises an error.

`constrained_kwargs` are extra *positional* options handed on to
`residual_with_linear_continuation` in this order: `eval_at_boundary`,
`boundary_epsilon`, `penalty_scale`. Options that are not supplied use the defaults
`false`, `1e-4` and `1e1`. `verbose` is forwarded as that function's last positional
argument, because it accepts no keyword arguments.
"""
function constrained_multivariate_residual(residual, bounds, method = "linear_continuation", verbose = false, constrained_kwargs...)
    if method == "linear_continuation"
        # Defaults for (eval_at_boundary, boundary_epsilon, penalty_scale). They have to be
        # filled in here so that `verbose` can be passed in the final positional slot.
        defaults = (false, 1e-4, 1e1)
        n_given = length(constrained_kwargs)
        if n_given > length(defaults)
            error("Too many options for method '$(method)': expected at most $(length(defaults)) (eval_at_boundary, boundary_epsilon, penalty_scale), got $(n_given).")
        end
        options = (constrained_kwargs..., defaults[(n_given + 1):end]...)
        return residual_with_linear_continuation(residual, bounds, options..., verbose)
    else
        error("Method '$(method)' for constrained multivariate root-finding has not yet been implemented.")
    end
end
